/*
 * Based on arch/arm/kernel/traps.c
 *
 * Copyright (C) 1995-2009 Russell King
 * Copyright (C) 2012 ARM Ltd.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
#include <seminix/interrupt.h>
#include <seminix/mmap.h>
#include <seminix/uaccess.h>
#include <seminix/sched/signal.h>
#include <asm/insn.h>
#include <asm/traps.h>
#include <asm/daifflags.h>
#include <asm/system_misc.h>
#include <asm/stacktrace.h>
#include <asm/arch_timer.h>
#include <asm/debug.h>

/*
 * Human-readable names of the exception-vector kinds, indexed by the
 * @reason argument bad_mode() receives. Both the pointers and the strings
 * are read-only.
 */
static const char *const handler[] = {
    "Synchronous Abort",
    "IRQ",
    "FIQ",
    "Error"
};

/* Print one call-trace line: @where symbolised via %pS. */
static void dump_backtrace_entry(unsigned long where)
{
    void *sym_addr = (void *)where;

    printk(" %pS\n", sym_addr);
}

static void __dump_instr(const char *lvl, struct pt_regs *regs)
{
    unsigned long addr = instruction_pointer(regs);
    char str[sizeof("00000000 ") * 5 + 2 + 1], *p = str;
    int i;

    for (i = -4; i < 1; i++) {
        unsigned int val, bad;

        bad = get_user(val, &((u32 *)addr)[i]);

        if (!bad)
            p += sprintf(p, i == 0 ? "(%08x) " : "%08x ", val);
        else {
            p += sprintf(p, "bad PC value");
            break;
        }
    }
    printk("%sCode: %s\n", lvl, str);
}

/*
 * Dump the instruction words around the PC in @regs. For a kernel-mode PC
 * the address limit is temporarily switched to KERNEL_DS around the dump
 * (presumably so get_user() accepts the kernel address) and restored after.
 */
static void dump_instr(const char *lvl, struct pt_regs *regs)
{
    mm_segment_t saved_fs;

    if (user_mode(regs)) {
        __dump_instr(lvl, regs);
        return;
    }

    saved_fs = get_fs();
    set_fs(KERNEL_DS);
    __dump_instr(lvl, regs);
    set_fs(saved_fs);
}

/*
 * Step @frame one record up the call chain of @tsk (NULL means current).
 *
 * The AArch64 PCS uses x29 as the frame pointer; it points at a frame
 * record holding the caller's x29 followed by x30 (the return address).
 * A simple prologue/epilogue pair looks like:
 *
 *	sub	sp, sp, #0x10
 *	stp	x29, x30, [sp]
 *	mov	x29, sp
 *	...
 *	mov	sp, x29
 *	ldp	x29, x30, [sp]
 *	add	sp, sp, #0x10
 *
 * Returns 0 when @frame now describes the caller. Returns -EINVAL to stop
 * the walk when the FP is not 16-byte aligned, when it is not on one of the
 * task's accessible stacks, or when the loaded record is all zero. In the
 * all-zero case @frame has already been overwritten.
 */
static int unwind_frame(struct task_struct *tsk, struct stackframe *frame)
{
    unsigned long record = frame->fp;

    if (record & 0xf)
        return -EINVAL;

    if (tsk == NULL)
        tsk = current;

    if (!on_accessible_stack(tsk, record, NULL))
        return -EINVAL;

    /* Slot 0 holds the caller's FP, slot 1 its return address. */
    frame->fp = READ_ONCE(((unsigned long *)record)[0]);
    frame->pc = READ_ONCE(((unsigned long *)record)[1]);

    /*
     * Frames created on entry from EL0 carry NULL FP and PC, so there is
     * nothing useful past them. __noreturn functions may leave a valid FP
     * next to a bogus PC, so only stop when both are NULL.
     */
    return (frame->fp == 0 && frame->pc == 0) ? -EINVAL : 0;
}

void dump_backtrace(struct pt_regs *regs, struct task_struct *tsk)
{
    struct stackframe frame;
    int skip;

    pr_debug("%s(regs = %p tsk = %p)\n", __func__, regs, tsk);

    if (!tsk)
        tsk = current;

    if (!try_get_task_stack(tsk))
        return;

    if (tsk == current) {
        frame.fp = (unsigned long)__builtin_frame_address(0);
        frame.pc = (unsigned long)dump_backtrace;
    } else {
        /*
         * task blocked in __switch_to
         */
        frame.fp = thread_saved_fp(tsk);
        frame.pc = thread_saved_pc(tsk);
    }

    skip = !!regs;
    printk("Call trace:\n");
    do {
        /* skip until specified stack frame */
        if (!skip) {
            dump_backtrace_entry(frame.pc);
        } else if (frame.fp == regs->regs[29]) {
            skip = 0;
            /*
             * Mostly, this is the case where this function is
             * called in panic/abort. As exception handler's
             * stack frame does not contain the corresponding pc
             * at which an exception has taken place, use regs->pc
             * instead.
             */
            dump_backtrace_entry(regs->pc);
        }
    } while (!unwind_frame(tsk, &frame));

    put_task_stack(tsk);
}

/*
 * Print the call trace of @tsk (NULL means the current task). @sp is part
 * of the interface but is not used here.
 */
void show_stack(struct task_struct *tsk, unsigned long *sp)
{
    dump_backtrace(NULL, tsk);
    /*
     * NOTE(review): presumably prevents the call above from being emitted
     * as a tail call, so this frame stays live during the walk — confirm.
     */
    barrier();
}

/* Appended unconditionally to the oops banner below. */
#define S_PREEMPT " PREEMPT"
#define S_SMP " SMP"

/*
 * Print the oops report for the current task: a numbered banner with @str
 * and @err, the registers, and the task identity. For a kernel-mode @regs
 * it also prints a backtrace and the code around the PC.
 */
static void __die(const char *str, int err, struct pt_regs *regs)
{
    static int die_counter;
    struct task_struct *task = current;
    int seq = ++die_counter;

    pr_emerg("Internal error: %s: %x [#%d]" S_PREEMPT S_SMP "\n",
         str, err, seq);

    __show_regs(regs);
    pr_emerg("Process %.*s (tid: %d, stack limit = 0x%p)\n",
         TASK_COMM_LEN, task->comm, task_pid_nr(task),
         end_of_stack(task));

    if (user_mode(regs))
        return;

    dump_backtrace(regs, task);
    dump_instr(KERN_EMERG, regs);
}

static DEFINE_RAW_SPINLOCK(die_lock);

/*
 * Report a fatal kernel error via __die() and panic. Never returns.
 *
 * die_lock (taken with IRQs saved/disabled) serialises callers on
 * different CPUs so their reports do not interleave. NOTE(review): it is
 * a plain raw spinlock, so a nested die() on the same CPU would presumably
 * spin on it rather than being "re-entrancy safe" — confirm.
 */
void die(const char *msg, struct pt_regs *regs, int err)
{
    unsigned long irqflags;

    raw_spin_lock_irqsave(&die_lock, irqflags);

    __die(msg, err, regs);

    if (!in_interrupt())
        panic("Fatal exception");
    else
        panic("Fatal exception in interrupt");

    /* Not reached: panic() does not return. */
    raw_spin_unlock_irqrestore(&die_lock, irqflags);
    unreachable();
}

/*
 * Log an unhandled exception for the current task. The log line contains:
 * comm[pid]; the ESR class and value when thread.fault_code is non-zero;
 * @str; and the VMA containing the saved PC. It then dumps task_pt_regs().
 * @signo is not used here.
 */
static void arm64_show_signal(int signo, const char *str)
{
    struct task_struct *task = current;
    struct pt_regs *saved = task_pt_regs(task);
    unsigned int code = task->thread.fault_code;

    pr_info("%s[%d]: unhandled exception: ", task->comm, task_pid_nr(task));
    if (code != 0)
        pr_cont("%s, ESR 0x%08x, ", esr_get_class_string(code), code);
    pr_cont("%s", str);
    print_vma_addr(KERN_CONT " in ", saved->pc);
    pr_cont("\n");
    __show_regs(saved);
}

/* Log the exception for the current task, then force @signo/@code at @addr on it. */
void arm64_force_sig_fault(int signo, int code, void __user *addr,
               const char *str)
{
    struct task_struct *task = current;

    arm64_show_signal(signo, str);
    force_sig_fault(signo, code, addr, task);
}

/* Log the exception as SIGBUS, then force a machine-check SIGBUS (@code, @addr, @lsb) on current. */
void arm64_force_sig_mceerr(int code, void __user *addr, short lsb,
                const char *str)
{
    struct task_struct *task = current;

    arm64_show_signal(SIGBUS, str);
    force_sig_mceerr(code, addr, lsb, task);
}

/*
 * Deliver a fault. Kernel-mode @regs go to die(). For user-mode @regs,
 * @err is recorded as the thread's fault code, the fault address is
 * cleared, and @signo/@sicode is forced at @addr.
 */
void arm64_notify_die(const char *str, struct pt_regs *regs,
              int signo, int sicode, void __user *addr,
              int err)
{
    if (!user_mode(regs)) {
        die(str, regs, err);
        return;
    }

    /* User faults are expected to be reported against the current task's regs. */
    WARN_ON(regs != current_pt_regs());
    current->thread.fault_address = 0;
    current->thread.fault_code = err;

    arm64_force_sig_fault(signo, sicode, addr, str);
}

/* Move the saved PC forward by @size bytes, past the trapping instruction. */
void arm64_skip_faulting_instruction(struct pt_regs *regs, unsigned long size)
{
    unsigned long next_pc = regs->pc + size;

    regs->pc = next_pc;
}

/* Registered undefined-instruction hooks; the list is guarded by undef_lock. */
static LIST_HEAD(undef_hook);
static DEFINE_RAW_SPINLOCK(undef_lock);

/* Insert @hook at the head of the list consulted by call_undef_hook(). */
void register_undef_hook(struct undef_hook *hook)
{
    unsigned long irqflags;

    raw_spin_lock_irqsave(&undef_lock, irqflags);
    list_add(&hook->node, &undef_hook);
    raw_spin_unlock_irqrestore(&undef_lock, irqflags);
}

/* Remove @hook from the undefined-instruction hook list under undef_lock. */
void unregister_undef_hook(struct undef_hook *hook)
{
    unsigned long irqflags;

    raw_spin_lock_irqsave(&undef_lock, irqflags);
    list_del(&hook->node);
    raw_spin_unlock_irqrestore(&undef_lock, irqflags);
}

static int call_undef_hook(struct pt_regs *regs)
{
    struct undef_hook *hook;
    unsigned long flags;
    u32 instr;
    int (*fn)(struct pt_regs *regs, u32 instr) = NULL;
    void __user *pc = (void __user *)instruction_pointer(regs);

    if (!user_mode(regs)) {
        __le32 instr_le;
        if (probe_kernel_address((__le32 *)pc, instr_le))
            goto exit;
        instr = le32_to_cpu(instr_le);
    } else {
        /* 32-bit ARM instruction */
        __le32 instr_le;
        if (get_user(instr_le, (__le32 __user *)pc))
            goto exit;
        instr = le32_to_cpu(instr_le);
    }

    raw_spin_lock_irqsave(&undef_lock, flags);
    list_for_each_entry(hook, &undef_hook, node)
        if ((instr & hook->instr_mask) == hook->instr_val &&
            (regs->pstate & hook->pstate_mask) == hook->pstate_val)
            fn = hook->fn;

    raw_spin_unlock_irqrestore(&undef_lock, flags);
exit:
    return fn ? fn(regs, instr) : 1;
}

/*
 * Send @signal/@code at @address to the current task, which must have
 * trapped from user mode (warns and does nothing otherwise). The log
 * description is derived from @signal.
 */
void force_signal_inject(int signal, int code, unsigned long address)
{
    struct pt_regs *regs = current_pt_regs();
    const char *desc;

    if (WARN_ON(!user_mode(regs)))
        return;

    if (signal == SIGILL)
        desc = "undefined instruction";
    else if (signal == SIGSEGV)
        desc = "illegal memory access";
    else
        desc = "unknown or unrecoverable error";

    /* err == 0: no ESR is recorded as fault_code, so none is logged. */
    arm64_notify_die(desc, regs, signal, code, (void __user *)address, 0);
}

/*
 * Send SIGSEGV for a failed access to user @addr. The code is SEGV_MAPERR
 * when find_vma() returns no VMA for @addr, otherwise SEGV_ACCERR.
 *
 * NOTE(review): if find_vma() follows Linux semantics (first VMA with
 * vm_end > addr), an address in a hole below a VMA is reported as
 * SEGV_ACCERR — confirm that is intended.
 */
void arm64_notify_segfault(unsigned long addr)
{
    int code;

    down_read(&current->mm->mmap_sem);
    code = find_vma(current->mm, addr) ? SEGV_ACCERR : SEGV_MAPERR;
    up_read(&current->mm->mmap_sem);

    force_signal_inject(SIGSEGV, code, addr);
}

/*
 * Undefined-instruction exception. A registered hook gets the first
 * chance. Otherwise a kernel-mode UNDEF is a BUG, and a user-mode one
 * raises SIGILL/ILL_ILLOPC at the PC.
 */
asmlinkage void __exception do_undefinstr(struct pt_regs *regs)
{
    if (!call_undef_hook(regs))
        return;

    BUG_ON(user_mode(regs) == 0);
    force_signal_inject(SIGILL, ILL_ILLOPC, regs->pc);
}

/*
 * Execute cache-maintenance instruction @insn on user @address on the
 * user's behalf. Sets @res to 0 on success. Sets it to -EFAULT when
 * @address is not below user_addr_max(), or when the instruction faults;
 * the fault is fixed up via the exception table.
 *
 * Wrapped in do { } while (0) so the macro is a single statement and is
 * safe in an unbraced if/else. Arguments are parenthesised.
 */
#define __user_cache_maint(insn, address, res)			\
    do {							\
        if ((address) >= user_addr_max()) {		\
            (res) = -EFAULT;				\
        } else {					\
            asm volatile (				\
                "1:	" insn ", %1\n"		\
                "	mov	%w0, #0\n"	\
                "2:\n"				\
                "	.pushsection .fixup,\"ax\"\n"	\
                "	.align	2\n"		\
                "3:	mov	%w0, %w2\n"	\
                "	b	2b\n"		\
                "	.popsection\n"		\
                _ASM_EXTABLE(1b, 3b)		\
                : "=r" (res)			\
                : "r" (address), "i" (-EFAULT));	\
        }						\
    } while (0)

/*
 * Emulate an EL0 cache-maintenance instruction that trapped to the kernel.
 * DC CVAU and DC CVAC are promoted to DC CIVAC. An unrecognised CRm raises
 * SIGILL. A failed or out-of-range access raises SIGSEGV. Otherwise the
 * instruction is stepped over.
 */
static void user_cache_maint_handler(unsigned int esr, struct pt_regs *regs)
{
    int crm = (esr & ESR_ELx_SYS64_ISS_CRM_MASK) >> ESR_ELx_SYS64_ISS_CRM_SHIFT;
    int rt = ESR_ELx_SYS64_ISS_RT(esr);
    unsigned long uaddr = untagged_addr(pt_regs_read_reg(regs, rt));
    int err = 0;

    switch (crm) {
    case ESR_ELx_SYS64_ISS_CRM_DC_CVAU:	/* DC CVAU, gets promoted */
    case ESR_ELx_SYS64_ISS_CRM_DC_CVAC:	/* DC CVAC, gets promoted */
    case ESR_ELx_SYS64_ISS_CRM_DC_CIVAC:	/* DC CIVAC */
        __user_cache_maint("dc civac", uaddr, err);
        break;
    case ESR_ELx_SYS64_ISS_CRM_DC_CVAP:	/* DC CVAP */
        __user_cache_maint("sys 3, c7, c12, 1", uaddr, err);
        break;
    case ESR_ELx_SYS64_ISS_CRM_IC_IVAU:	/* IC IVAU */
        __user_cache_maint("ic ivau", uaddr, err);
        break;
    default:
        force_signal_inject(SIGILL, ILL_ILLOPC, regs->pc);
        return;
    }

    if (err) {
        arm64_notify_segfault(uaddr);
        return;
    }
    arm64_skip_faulting_instruction(regs, AARCH64_INSN_SIZE);
}

/*
 * Emulate a trapped EL0 read of CTR_EL0. The destination register gets
 * arm64_ftr_reg_user_value() for CTR_EL0, then the instruction is skipped.
 */
static void ctr_read_handler(unsigned int esr, struct pt_regs *regs)
{
    pt_regs_write_reg(regs, ESR_ELx_SYS64_ISS_RT(esr),
              arm64_ftr_reg_user_value(&arm64_ftr_reg_ctrel0));
    arm64_skip_faulting_instruction(regs, AARCH64_INSN_SIZE);
}

/* Emulate a trapped EL0 read of CNTVCT_EL0, then skip the instruction. */
static void cntvct_read_handler(unsigned int esr, struct pt_regs *regs)
{
    int reg = ESR_ELx_SYS64_ISS_RT(esr);
    unsigned long count = arch_counter_get_cntvct();

    pt_regs_write_reg(regs, reg, count);
    arm64_skip_faulting_instruction(regs, AARCH64_INSN_SIZE);
}

/* Emulate a trapped EL0 read of CNTFRQ_EL0 using arch_timer_get_rate(). */
static void cntfrq_read_handler(unsigned int esr, struct pt_regs *regs)
{
    int reg = ESR_ELx_SYS64_ISS_RT(esr);
    unsigned long rate = arch_timer_get_rate();

    pt_regs_write_reg(regs, reg, rate);
    arm64_skip_faulting_instruction(regs, AARCH64_INSN_SIZE);
}

/*
 * Emulate a trapped EL0 MRS of an ID register through do_emulate_mrs().
 * If emulation fails, raise SIGILL at the PC.
 */
static void mrs_handler(unsigned int esr, struct pt_regs *regs)
{
    u32 rt = ESR_ELx_SYS64_ISS_RT(esr);
    u32 sysreg = esr_sys64_to_sysreg(esr);

    if (do_emulate_mrs(regs, sysreg, rt))
        force_signal_inject(SIGILL, ILL_ILLOPC, regs->pc);
}

/*
 * Trapped EL0 WFI: nothing is emulated; the instruction is simply stepped
 * over. @esr is not used.
 */
static void wfi_handler(unsigned int esr, struct pt_regs *regs)
{
    arm64_skip_faulting_instruction(regs, AARCH64_INSN_SIZE);
}

/*
 * One entry of the trapped-system-instruction dispatch table:
 * do_sysinstr() calls @handler for an ESR when (esr & esr_mask) == esr_val.
 */
struct sys64_hook {
    unsigned int esr_mask;
    unsigned int esr_val;
    void (*handler)(unsigned int esr, struct pt_regs *regs);
};

/*
 * Dispatch table for trapped EL0 system instructions. do_sysinstr() scans
 * it in order and uses the first matching entry, so order matters. The
 * empty entry at the end (NULL handler) terminates the scan.
 */
static struct sys64_hook sys64_hooks[] = {
    {
        /* EL0 cache-maintenance operations (DC/IC by VA) */
        .esr_mask = ESR_ELx_SYS64_ISS_EL0_CACHE_OP_MASK,
        .esr_val = ESR_ELx_SYS64_ISS_EL0_CACHE_OP_VAL,
        .handler = user_cache_maint_handler,
    },
    {
        /* Trap read access to CTR_EL0 */
        .esr_mask = ESR_ELx_SYS64_ISS_SYS_OP_MASK,
        .esr_val = ESR_ELx_SYS64_ISS_SYS_CTR_READ,
        .handler = ctr_read_handler,
    },
    {
        /* Trap read access to CNTVCT_EL0 */
        .esr_mask = ESR_ELx_SYS64_ISS_SYS_OP_MASK,
        .esr_val = ESR_ELx_SYS64_ISS_SYS_CNTVCT,
        .handler = cntvct_read_handler,
    },
    {
        /* Trap read access to CNTFRQ_EL0 */
        .esr_mask = ESR_ELx_SYS64_ISS_SYS_OP_MASK,
        .esr_val = ESR_ELx_SYS64_ISS_SYS_CNTFRQ,
        .handler = cntfrq_read_handler,
    },
    {
        /* Trap read access to CPUID registers */
        .esr_mask = ESR_ELx_SYS64_ISS_SYS_MRS_OP_MASK,
        .esr_val = ESR_ELx_SYS64_ISS_SYS_MRS_OP_VAL,
        .handler = mrs_handler,
    },
    {
        /* Trap WFI instructions executed in userspace */
        .esr_mask = ESR_ELx_WFx_MASK,
        .esr_val = ESR_ELx_WFx_WFI_VAL,
        .handler = wfi_handler,
    },
    /* Sentinel: NULL handler ends the scan. */
    {},
};

/*
 * Trapped system instruction. The first sys64_hooks entry whose masked ESR
 * matches handles it. If none matches, it is treated as an undefined
 * instruction.
 */
asmlinkage void __exception do_sysinstr(unsigned int esr, struct pt_regs *regs)
{
    const struct sys64_hook *hook = sys64_hooks;

    while (hook->handler) {
        if ((esr & hook->esr_mask) == hook->esr_val) {
            hook->handler(esr, regs);
            return;
        }
        hook++;
    }

    /*
     * New SYS instructions may previously have been undefined at EL0. Fall
     * back to our usual undefined instruction handler so that we handle
     * these consistently.
     */
    do_undefinstr(regs);
}

static const char *esr_class_str[] = {
    [0 ... ESR_ELx_EC_MAX]		= "UNRECOGNIZED EC",
    [ESR_ELx_EC_UNKNOWN]		= "Unknown/Uncategorized",
    [ESR_ELx_EC_WFx]		= "WFI/WFE",
    [ESR_ELx_EC_CP15_32]		= "CP15 MCR/MRC",
    [ESR_ELx_EC_CP15_64]		= "CP15 MCRR/MRRC",
    [ESR_ELx_EC_CP14_MR]		= "CP14 MCR/MRC",
    [ESR_ELx_EC_CP14_LS]		= "CP14 LDC/STC",
    [ESR_ELx_EC_FP_ASIMD]		= "ASIMD",
    [ESR_ELx_EC_CP10_ID]		= "CP10 MRC/VMRS",
    [ESR_ELx_EC_CP14_64]		= "CP14 MCRR/MRRC",
    [ESR_ELx_EC_ILL]		= "PSTATE.IL",
    [ESR_ELx_EC_SVC32]		= "SVC (AArch32)",
    [ESR_ELx_EC_HVC32]		= "HVC (AArch32)",
    [ESR_ELx_EC_SMC32]		= "SMC (AArch32)",
    [ESR_ELx_EC_SVC64]		= "SVC (AArch64)",
    [ESR_ELx_EC_HVC64]		= "HVC (AArch64)",
    [ESR_ELx_EC_SMC64]		= "SMC (AArch64)",
    [ESR_ELx_EC_SYS64]		= "MSR/MRS (AArch64)",
    [ESR_ELx_EC_SVE]		= "SVE",
    [ESR_ELx_EC_IMP_DEF]		= "EL3 IMP DEF",
    [ESR_ELx_EC_IABT_LOW]		= "IABT (lower EL)",
    [ESR_ELx_EC_IABT_CUR]		= "IABT (current EL)",
    [ESR_ELx_EC_PC_ALIGN]		= "PC Alignment",
    [ESR_ELx_EC_DABT_LOW]		= "DABT (lower EL)",
    [ESR_ELx_EC_DABT_CUR]		= "DABT (current EL)",
    [ESR_ELx_EC_SP_ALIGN]		= "SP Alignment",
    [ESR_ELx_EC_FP_EXC32]		= "FP (AArch32)",
    [ESR_ELx_EC_FP_EXC64]		= "FP (AArch64)",
    [ESR_ELx_EC_SERROR]		= "SError",
    [ESR_ELx_EC_BREAKPT_LOW]	= "Breakpoint (lower EL)",
    [ESR_ELx_EC_BREAKPT_CUR]	= "Breakpoint (current EL)",
    [ESR_ELx_EC_SOFTSTP_LOW]	= "Software Step (lower EL)",
    [ESR_ELx_EC_SOFTSTP_CUR]	= "Software Step (current EL)",
    [ESR_ELx_EC_WATCHPT_LOW]	= "Watchpoint (lower EL)",
    [ESR_ELx_EC_WATCHPT_CUR]	= "Watchpoint (current EL)",
    [ESR_ELx_EC_BKPT32]		= "BKPT (AArch32)",
    [ESR_ELx_EC_VECTOR32]		= "Vector catch (AArch32)",
    [ESR_ELx_EC_BRK64]		= "BRK (AArch64)",
};

/* Return the printable name of the exception class encoded in @esr. */
const char *esr_get_class_string(u32 esr)
{
    u32 ec = ESR_ELx_EC(esr);

    return esr_class_str[ec];
}

/*
 * bad_mode handles the impossible case in the exception vector. This is
 * always fatal. It logs the vector kind, CPU and ESR, dumps @regs, masks
 * exceptions (DAIF) and panics.
 */
asmlinkage void bad_mode(struct pt_regs *regs, int reason, unsigned int esr)
{
    /*
     * @reason indexes handler[]. Guard it so a bogus value from the vector
     * code cannot turn this last-gasp report into an out-of-bounds read.
     */
    const char *kind =
        (reason >= 0 &&
         (size_t)reason < sizeof(handler) / sizeof(handler[0])) ?
        handler[reason] : "Unknown";

    pr_crit("Bad mode in %s handler detected on CPU%d, code 0x%08x -- %s\n",
        kind, smp_processor_id(), esr,
        esr_get_class_string(esr));

    /* The register state is the main diagnostic; match arm64_serror_panic(). */
    if (regs)
        __show_regs(regs);

    local_daif_mask();
    panic("bad mode");
}

asmlinkage void bad_el0_sync(struct pt_regs *regs, int reason, unsigned int esr)
{
    void __user *pc = (void __user *)instruction_pointer(regs);

    current->thread.fault_address = 0;
    current->thread.fault_code = esr;

    arm64_force_sig_fault(SIGILL, ILL_ILLOPC, pc,
                  "Bad EL0 synchronous exception");
}

/*
 * Report an SError (CPU, ESR value and class), dump @regs if supplied, and
 * panic. Never returns.
 */
static void __noreturn arm64_serror_panic(struct pt_regs *regs, u32 esr)
{
    int cpu = smp_processor_id();
    const char *cls = esr_get_class_string(esr);

    pr_crit("SError Interrupt on CPU%d, code 0x%08x -- %s\n",
        cpu, esr, cls);
    if (regs != NULL)
        __show_regs(regs);

    panic("Asynchronous SError Interrupt");

    unreachable();
}

/*
 * Classify a RAS SError by its AET severity:
 *  - CE (corrected) / UEO (restartable, not yet consumed): not fatal. The
 *    CPU can make progress; a UEO may be taken again as a more severe error.
 *  - UEU / UER: fatal. The CPU can't make progress and the exception may
 *    have been imprecise.
 *  - UC or anything else: the error has been silently propagated, so panic
 *    immediately.
 */
static bool arm64_is_fatal_ras_serror(struct pt_regs *regs, unsigned int esr)
{
    u32 severity = arm64_ras_serror_get_severity(esr);

    if (severity == ESR_ELx_AET_CE || severity == ESR_ELx_AET_UEO)
        return false;

    if (severity == ESR_ELx_AET_UEU || severity == ESR_ELx_AET_UER)
        return true;

    /* ESR_ELx_AET_UC and unknown encodings */
    arm64_serror_panic(regs, esr);
}

/*
 * SError entry point, run between nmi_enter()/nmi_exit(). Panics for
 * non-RAS SErrors, which are not containable, and for RAS SErrors judged
 * fatal. Otherwise it returns.
 */
asmlinkage void do_serror(struct pt_regs *regs, unsigned int esr)
{
    int fatal;

    nmi_enter();

    fatal = !arm64_is_ras_serror(esr) ||
        arm64_is_fatal_ras_serror(regs, esr);
    if (fatal)
        arm64_serror_panic(regs, esr);

    nmi_exit();
}

/*
 * BRK handler for kernel BUG()s. A BRK from user mode is not ours
 * (DBG_HOOK_ERROR). A kernel-mode one calls die(), which does not return.
 */
static int __init bug_handler(struct pt_regs *regs, unsigned int esr)
{
    if (!user_mode(regs)) {
        die("Oops - BUG", regs, 0);
        return DBG_HOOK_HANDLED;
    }

    return DBG_HOOK_ERROR;
}

/*
 * Initial handler for AArch64 BRK exceptions.
 * This handler is only used until debug_traps_init().
 * Returns 0 if bug_handler() handled the BRK, 1 otherwise. @addr is unused.
 */
int __init early_brk64(unsigned long addr, unsigned int esr,
        struct pt_regs *regs)
{
    int rc = bug_handler(regs, esr);

    return rc == DBG_HOOK_HANDLED ? 0 : 1;
}
